package com.jplus.plugins.mybatis.plugin;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.jdbc.ConnectionLogger;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.transaction.Transaction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.jplus.framework.util.FormatUtil;

/**
 * Mybatis分页插件<br>
 * 查询时把List放入参数page中并返回<br/>
 * 使用说明：
 * 
 * 
 * @author Yuanqy
 */
@Intercepts({ @Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
public class MybatisInterceptor_Page implements Interceptor {
	private Logger log=LoggerFactory.getLogger(getClass());
	private static ThreadLocal<String> DBtype = new ThreadLocal<String>();
	private String defDBType;
	// 用户提供分页计算条数后缀，默认是_count
	private String subStr = "_count";
	// 存储所有语句名称
	final HashMap<String, String> map_statement = new HashMap<String, String>();

	/**
	 * 设置匹配分页总数sql后缀，如果有这个后缀的查询sql.优先执行。
	 */
	public void SetSubCount(String subStr) {
		this.subStr = subStr;
	}

	public static void SetDBtype(String tempDBType) {
		DBtype.set(tempDBType);
	}

	public String getDBtype() {
		return FormatUtil.isEmpty(DBtype.get()) ? defDBType : DBtype.get();
	}

	@SuppressWarnings("unused")
	private MybatisInterceptor_Page() {
		// 构造方法私有化
	}

	/**
	 * 设置数据库默认 类型<br>
	 * 中途想修改类型，请使用：MybatisInterceptor_Page.DBtype=xxx <br>
	 * 
	 * @param defDBType
	 *            ="oracle" or "mysql"
	 */
	public MybatisInterceptor_Page(String defDBType) {
		this.defDBType = (defDBType);
	}

	/**
	 * 获取所有statement语句的名称
	 * 
	 * @param configuration
	 */
	protected synchronized void initStatementMap(Configuration configuration) {
		if (!map_statement.isEmpty()) {
			return;
		}
		Collection<String> statements = configuration.getMappedStatementNames();
		for (Iterator<String> iter = statements.iterator(); iter.hasNext();) {
			String element = iter.next();
			map_statement.put(element, element);
		}
	}

	/**
	 * 获取数据库连接
	 */
	protected Connection getConnection(Transaction transaction, Log statementLog) throws SQLException {
		Connection connection = transaction.getConnection();
		if (statementLog.isDebugEnabled()) {
			return ConnectionLogger.newInstance(connection, statementLog);
		} else {
			return connection;
		}
	}

	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		Object parameter = invocation.getArgs()[1];
		JPage jPage = seekPage(parameter);
		if (jPage == null) {
			return invocation.proceed();
		} else {
			return handlePaging(invocation, parameter, jPage);
		}

	}

	/**
	 * 处理分页的情况
	 * <p>
	 * 
	 * @param invocation
	 * @param parameter
	 * @param jPage
	 * @throws SQLException
	 */
	@SuppressWarnings("rawtypes")
	protected List handlePaging(Invocation invocation, Object parameter, JPage jPage) throws Exception {
		MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
		Configuration configuration = mappedStatement.getConfiguration();
		if (map_statement.isEmpty()) {
			initStatementMap(configuration);
		}
		BoundSql boundSql = mappedStatement.getBoundSql(parameter);
		// 查询结果集
		StaticSqlSource sqlsource = new StaticSqlSource(configuration, buildPageSql(boundSql.getSql(), jPage), boundSql.getParameterMappings());
		MappedStatement.Builder builder = new MappedStatement.Builder(configuration, mappedStatement.getId(), sqlsource, SqlCommandType.SELECT);
		builder.resultMaps(mappedStatement.getResultMaps()).resultSetType(mappedStatement.getResultSetType()).statementType(mappedStatement.getStatementType());
		MappedStatement query_statement = builder.build();

		List data = (List) exeQuery(invocation, query_statement);
		// 设置到page对象
		jPage.setData(data);
		jPage.setCount(getTotalSize(invocation, configuration, mappedStatement, boundSql, parameter));

		return data;
	}

	/**
	 * 根据提供的语句执行查询操作
	 *
	 * @param invocation
	 * @param query_statement
	 * @return
	 * @throws Exception
	 */
	protected Object exeQuery(Invocation invocation, MappedStatement query_statement) throws Exception {
		Object[] args = invocation.getArgs();
		return invocation.getMethod().invoke(invocation.getTarget(), new Object[] { query_statement, args[1], args[2], args[3] });
	}

	/**
	 * 获取总记录数量
	 * <p>
	 *
	 * @param configuration
	 * @param mappedStatement
	 * @param sql
	 * @param parameter
	 * @return
	 * @throws SQLException
	 */
	@SuppressWarnings("rawtypes")
	protected int getTotalSize(Invocation invocation, Configuration configuration, MappedStatement mappedStatement, BoundSql boundSql, Object parameter)
			throws Exception {
		String count_id = mappedStatement.getId() + subStr;
		int totalSize = 0;
		if (map_statement.containsKey(count_id)) {
			// 优先查找能统计条数的sql
			List data = (List) exeQuery(invocation, mappedStatement.getConfiguration().getMappedStatement(count_id));
			if (data.size() > 0) {
				totalSize = Integer.parseInt(data.get(0).toString());
			}
		} else {
			Executor exe = (Executor) invocation.getTarget();
			Connection connection = getConnection(exe.getTransaction(), mappedStatement.getStatementLog());
			String countSql = getCountSql(boundSql.getSql());
			log.debug("[分页插件]：" + countSql);
			totalSize = getTotalSize(configuration, mappedStatement, boundSql, countSql, connection, parameter);
		}
		return totalSize;
	}

	/**
	 * 拼接查询sql,加入分页
	 *
	 * @param sql
	 * @param jPage
	 */
	protected String buildPageSql(String sql, JPage jPage) {
		StringBuilder sb = new StringBuilder(sql.length() + 100);
		if (getDBtype().equalsIgnoreCase("oracle")) {
			String beginrow = String.valueOf((jPage.getPage() - 1) * jPage.getPageSize());
			String endrow = String.valueOf(jPage.getPage() * jPage.getPageSize());
			sb.append("select * from ( select temp.*, rownum row_id from ( ");
			sb.append(sql);
			sb.append(" ) temp where rownum <= ").append(endrow);
			sb.append(") where row_id > ").append(beginrow);
		}
		if (getDBtype().equalsIgnoreCase("mysql")) {
			sb.append(sql);
			sb.append(" limit ").append(jPage.getStartNo() - 1).append(",").append(jPage.getPageSize());
		}
		return sb.toString();
	}

	/**
	 * 拼接获取条数的sql语句
	 * <p>
	 *
	 * @param sqlPrimary
	 */
	protected String getCountSql(String sqlPrimary) {
		if (getDBtype().equalsIgnoreCase("oracle")) {
			return "SELECT COUNT(1) AS cnt FROM (" + sqlPrimary + ") t";
		}
		if (getDBtype().equalsIgnoreCase("mysql")) {
			return "SELECT COUNT(1) AS cnt FROM (" + sqlPrimary + ") t";
		}
		return null;
	}

	/**
	 * 计算总条数
	 * <p>
	 *
	 * @param parameterObj
	 * @param countSql
	 * @param connection
	 * @return
	 */
	protected int getTotalSize(Configuration configuration, MappedStatement mappedStatement, BoundSql boundSql, String countSql, Connection connection,
			Object parameter) throws SQLException {
		PreparedStatement stmt = null;
		ResultSet rs = null;
		int totalSize = 0;
		try {
			ParameterHandler handler = configuration.newParameterHandler(mappedStatement, parameter, boundSql);
			stmt = connection.prepareStatement(countSql);
			handler.setParameters(stmt);
			rs = stmt.executeQuery();
			if (rs.next()) {
				totalSize = rs.getInt(1);
			}
		} catch (SQLException e) {
			throw e;
		} finally {
			if (rs != null) {
				rs.close();
				rs = null;
			}
			if (stmt != null) {
				stmt.close();
				stmt = null;
			}
		}
		return totalSize;
	}

	/**
	 * 寻找page对象
	 * <p>
	 *
	 * @param parameter
	 */
	@SuppressWarnings("rawtypes")
	protected JPage seekPage(Object parameter) {
		JPage jPage = null;
		if (parameter == null) {
			return null;
		}
		if (parameter instanceof JPage) {
			jPage = (JPage) parameter;
		} else if (parameter instanceof Map) {
			Map map = (Map) parameter;
			for (Object arg : map.values()) {
				if (arg instanceof JPage) {
					jPage = (JPage) arg;
				}
			}
		}
		return jPage;
	}

	@Override
	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	@Override
	public void setProperties(Properties properties) {
	}
}